{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Encoder Training\n",
    "=============="
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%pylab notebook\n",
    "\n",
    "import tensorflow as tf\n",
    "from tensorflow.keras import layers\n",
    "\n",
    "import primo.models\n",
    "import primo.datasets\n",
    "\n",
    "import cupyck"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Reserve space on the GPU for running simulations. It's important to do this before running any tensorflow code (which will take all available GPU memory):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cupyck_sess = cupyck.GPUSession(max_seqlen=200, nblocks=1000, nthreads=128)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "simulator = primo.models.Simulator(cupyck_sess)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Load up the training and validation datasets:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset = primo.datasets.OpenImagesTrain(\n",
    "    '/tf/open_images/train/', switch_every=10**5\n",
    ")\n",
    "\n",
    "val_dataset = primo.datasets.OpenImagesVal('/tf/open_images/validation/')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def keras_batch_generator(dataset_batch_generator, sim_thresh):\n",
    "    while True:\n",
    "        indices, pairs = next(dataset_batch_generator)\n",
    "        distances = np.sqrt(np.square(pairs[:,0,:] - pairs[:,1,:]).sum(1))\n",
    "        similar = (distances < sim_thresh).astype(int)\n",
    "        yield pairs, similar"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# To see how this value was derived, please consult the Materials and Methods subsection under Feature Extraction section.\n",
    "sim_thresh = 75\n",
    "# Intuitively determined:\n",
    "encoder_train_batch_size = 100\n",
    "encoder_val_batch_size = 2500\n",
    "predictor_train_batch_size = 1000\n",
    "\n",
    "encoder_train_batches = keras_batch_generator(\n",
    "    train_dataset.balanced_pairs(encoder_train_batch_size, sim_thresh),\n",
    "    sim_thresh\n",
    ")\n",
    "\n",
    "encoder_val_batches = keras_batch_generator(\n",
    "    val_dataset.random_pairs(encoder_val_batch_size),\n",
    "    sim_thresh\n",
    ")\n",
    "\n",
    "predictor_train_batches = train_dataset.random_pairs(1000)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create the models and stack them together with the trainer:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "encoder = primo.models.Encoder()\n",
    "yield_predictor = primo.models.Predictor('/tf/primo/data/models/yield-model.h5')\n",
    "encoder_trainer = primo.models.EncoderTrainer(encoder, yield_predictor)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Run the training!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "encoder_trainer.model.compile(tf.keras.optimizers.Adagrad(1e-3), 'binary_crossentropy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "history = encoder_trainer.model.fit_generator(\n",
    "    encoder_train_batches,\n",
    "    steps_per_epoch = 1000,\n",
    "    epochs = 100,\n",
    "    callbacks = [\n",
    "        encoder_trainer.refit_predictor(\n",
    "            predictor_train_batches, simulator, refit_every=1, refit_epochs=10\n",
    "        )\n",
    "    ],\n",
    "    validation_data = encoder_val_batches,\n",
    "    validation_steps = 1,\n",
    "    verbose = 2\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Save the models:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "encoder.save('/tf/primo/data/models/encoder_model.h5')\n",
    "predictor.save('/tf/primo/data/models/predictor_model.h5')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.15+"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}